[mlir][mesh,mpi] Lower allreduce #144060
Conversation
@llvm/pr-subscribers-mlir
Author: Frank Schlimbach (fschlimb)
Changes: Adds lowering of mesh.allreduce to mpi.allreduce.
Patch is 63.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144060.diff
12 Files Affected:
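For orientation, the kind of IR this lowering consumes looks roughly like the following sketch (mesh name, shape, and element types are illustrative, not taken from the PR's tests):

mesh.mesh @mesh0(shape = 4)
func.func @all_reduce(%arg0: memref<2x2xf32>) -> memref<2x2xf32> {
  // Reduce over mesh axis 0; the pattern added below rewrites this into an
  // mpi.comm_split over the reduced axes followed by an mpi.allreduce on a
  // buffer copy of the input.
  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = max
      : memref<2x2xf32> -> memref<2x2xf32>
  return %0 : memref<2x2xf32>
}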
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index b496ee0114910..5a864865adffc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -905,6 +905,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
shard/partition sizes depend on the rank.
}];
let dependentDialects = [
+ "affine::AffineDialect",
+ "arith::ArithDialect",
"memref::MemRefDialect",
"mpi::MPIDialect",
"scf::SCFDialect",
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.h b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
index f06b911ce3fe3..2b6743cd008c6 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPI.h
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.h
@@ -12,6 +12,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
//===----------------------------------------------------------------------===//
// MPIDialect
diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
index d78aa92d201e7..c14837f6961eb 100644
--- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
+++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
@@ -11,6 +11,7 @@
include "mlir/Dialect/MPI/IR/MPI.td"
include "mlir/Dialect/MPI/IR/MPITypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
class MPI_Op<string mnemonic, list<Trait> traits = []>
: Op<MPI_Dialect, mnemonic, traits>;
@@ -41,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", []> {
// CommWorldOp
//===----------------------------------------------------------------------===//
-def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
+def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
let description = [{
This operation returns the predefined MPI_COMM_WORLD communicator.
@@ -56,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
// CommRankOp
//===----------------------------------------------------------------------===//
-def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
+def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
let summary = "Get the current rank, equivalent to "
"`MPI_Comm_rank(comm, &rank)`";
let description = [{
@@ -72,13 +73,14 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
);
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// CommSizeOp
//===----------------------------------------------------------------------===//
-def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
+def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
let summary = "Get the size of the group associated to the communicator, "
"equivalent to `MPI_Comm_size(comm, &size)`";
let description = [{
@@ -100,7 +102,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
// CommSplitOp
//===----------------------------------------------------------------------===//
-def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
+def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
let summary = "Partition the group associated with the given communicator into "
"disjoint subgroups";
let description = [{
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index f59c4c4c67517..ac05ee243d7be 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -584,11 +584,11 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
```
}];
let arguments = !con(commonArgs, (ins
- AnyRankedTensor:$input,
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
));
let results = (outs
- AnyRankedTensor:$result
+ AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
);
let assemblyFormat = [{
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index c64da29ca6412..3f1041cb25103 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -62,9 +62,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
auto isEndomorphismOp = [reduction](Operation *op,
std::optional<Operation *> referenceOp) {
auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
- if (!allReduceOp ||
- allReduceOp.getInput().getType().getElementType() !=
- allReduceOp.getResult().getType().getElementType() ||
+ if (!allReduceOp)
+   return false;
+ auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
+ auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
+ if (inType.getElementType() != outType.getElementType() ||
allReduceOp.getReduction() != reduction) {
return false;
}
@@ -83,9 +83,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
}
auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
+ auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
- allReduceOp.getInput().getType().getElementType() ==
- refAllReduceOp.getInput().getType().getElementType();
+ inType.getElementType() == refType.getElementType();
};
auto isAlgebraicOp = [](Operation *op) {
return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index be82e2af399dc..5a1154bf9166e 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -42,6 +42,10 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
ArrayRef<MeshAxis> meshAxes,
ImplicitLocOpBuilder &builder);
+TypedValue<IndexType>
+createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
+ ArrayRef<MeshAxis> meshAxes,
+ ImplicitLocOpBuilder &builder);
} // namespace mesh
} // namespace mlir
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index 823d4d644f586..521569e69b61a 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -12,9 +12,9 @@
#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,6 +22,8 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -289,27 +291,15 @@ struct ConvertProcessMultiIndexOp
class ConvertProcessLinearIndexOp
: public OpConversionPattern<ProcessLinearIndexOp> {
- int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0
public:
using OpConversionPattern::OpConversionPattern;
- // Constructor accepting worldRank
- ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
- MLIRContext *context, int64_t worldRank = -1)
- : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}
-
LogicalResult
matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
+ // Create mpi::CommRankOp
Location loc = op.getLoc();
- if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
- rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
- return success();
- }
-
- // Otherwise call create mpi::CommRankOp
auto ctx = op.getContext();
Value commWorld =
rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
@@ -529,6 +519,124 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
}
};
+static mpi::MPI_OpClassEnumAttr getMPIReduction(ReductionKindAttr kind) {
+ auto ctx = kind.getContext();
+ switch (kind.getValue()) {
+ case ReductionKind::Sum:
+ return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_SUM);
+ case ReductionKind::Product:
+ return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_PROD);
+ case ReductionKind::Min:
+ return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MIN);
+ case ReductionKind::Max:
+ return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_MAX);
+ case ReductionKind::BitwiseAnd:
+ return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BAND);
+ case ReductionKind::BitwiseOr:
+ return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BOR);
+ case ReductionKind::BitwiseXor:
+ return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_BXOR);
+ default:
+ assert(false && "Unknown/unsupported reduction kind");
+ }
+}
+
+struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ SymbolTableCollection symbolTableCollection;
+ auto mesh = adaptor.getMesh();
+ auto meshOp = getMesh(op, symbolTableCollection);
+ if (!meshOp)
+ return op->emitError() << "No mesh found for AllReduceOp";
+ if (ShapedType::isDynamicShape(meshOp.getShape()))
+ return op->emitError()
+ << "Dynamic mesh shape not supported in AllReduceOp";
+
+ ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
+ Value input = adaptor.getInput();
+ auto inputShape = cast<ShapedType>(input.getType()).getShape();
+
+ // If the source is a tensor, cast it to a memref.
+ if (isa<RankedTensorType>(input.getType())) {
+ auto memrefType = MemRefType::get(
+ inputShape, cast<ShapedType>(input.getType()).getElementType());
+ input = iBuilder.create<bufferization::ToMemrefOp>(memrefType, input);
+ }
+ MemRefType inType = cast<MemRefType>(input.getType());
+
+ // Get the actual shape to allocate the buffer.
+ SmallVector<OpFoldResult> shape(inType.getRank());
+ for (auto i = 0; i < inType.getRank(); ++i) {
+ auto s = inputShape[i];
+ if (ShapedType::isDynamic(s))
+ shape[i] = iBuilder.create<memref::DimOp>(input, i).getResult();
+ else
+ shape[i] = iBuilder.getIndexAttr(s);
+ }
+
+ // Allocate buffer and copy input to buffer.
+ Value buffer = iBuilder.create<memref::AllocOp>(
+ shape, cast<ShapedType>(op.getType()).getElementType());
+ iBuilder.create<linalg::CopyOp>(input, buffer);
+
+ // Get an MPI_Comm_split for the AllReduce operation.
+ // The color is the linear index of the process in the mesh along the
+ // non-reduced axes. The key is the linear index of the process in the mesh
+ // along the reduced axes.
+ SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
+ iBuilder.getIndexType());
+ SmallVector<Value> myMultiIndex =
+ iBuilder.create<ProcessMultiIndexOp>(indexResultTypes, mesh)
+ .getResult();
+ Value zero = iBuilder.create<arith::ConstantIndexOp>(0);
+ SmallVector<Value> multiKey(myMultiIndex.size(), zero);
+
+ auto redAxes = adaptor.getMeshAxes();
+ for (auto axis : redAxes) {
+ multiKey[axis] = myMultiIndex[axis];
+ myMultiIndex[axis] = zero;
+ }
+
+ Value color =
+ createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder);
+ color = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), color);
+ Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder);
+ key = iBuilder.create<arith::IndexCastOp>(iBuilder.getI32Type(), key);
+
+ // Finally split the communicator
+ auto commType = mpi::CommType::get(op->getContext());
+ Value commWorld = iBuilder.create<mpi::CommWorldOp>(commType);
+ auto comm =
+ iBuilder.create<mpi::CommSplitOp>(commType, commWorld, color, key)
+ .getNewcomm();
+
+ Value buffer1d = buffer;
+ // Collapse shape to 1d if needed
+ if (inType.getRank() > 1) {
+ ReassociationIndices reassociation(inType.getRank());
+ std::iota(reassociation.begin(), reassociation.end(), 0);
+ buffer1d = iBuilder.create<memref::CollapseShapeOp>(
+ buffer, ArrayRef<ReassociationIndices>(reassociation));
+ }
+
+ // Create the MPI AllReduce operation.
+ iBuilder.create<mpi::AllReduceOp>(
+ TypeRange(), buffer1d, buffer1d,
+ getMPIReduction(adaptor.getReductionAttr()), comm);
+
+ // If the result is a tensor, cast the buffer back to a tensor.
+ if (isa<RankedTensorType>(op.getType()))
+ buffer = iBuilder.create<bufferization::ToTensorOp>(buffer, true);
+
+ rewriter.replaceOp(op, buffer);
+ return success();
+ }
+};
+
struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
using OpConversionPattern::OpConversionPattern;
@@ -573,10 +681,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
Value array = dest;
if (isa<RankedTensorType>(array.getType())) {
// If the destination is a tensor, we need to cast it to a memref
- auto tensorType = MemRefType::get(
+ auto memrefType = MemRefType::get(
dstShape, cast<ShapedType>(array.getType()).getElementType());
array =
- rewriter.create<bufferization::ToBufferOp>(loc, tensorType, array);
+ rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, array);
}
auto rank = cast<ShapedType>(array.getType()).getRank();
auto opSplitAxes = adaptor.getSplitAxes().getAxes();
@@ -753,22 +861,6 @@ struct ConvertMeshToMPIPass
/// Run the dialect converter on the module.
void runOnOperation() override {
- uint64_t worldRank = -1;
- // Try to get DLTI attribute for MPI:comm_world_rank
- // If found, set worldRank to the value of the attribute.
- {
- auto dltiAttr =
- dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
- if (succeeded(dltiAttr)) {
- if (!isa<IntegerAttr>(dltiAttr.value())) {
- getOperation()->emitError()
- << "Expected an integer attribute for MPI:comm_world_rank";
- return signalPassFailure();
- }
- worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
- }
- }
-
auto *ctxt = &getContext();
RewritePatternSet patterns(ctxt);
ConversionTarget target(getContext());
@@ -819,10 +911,10 @@ struct ConvertMeshToMPIPass
// ...except the global MeshOp
target.addLegalOp<mesh::MeshOp>();
// Allow all the stuff that our patterns will convert to
- target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
- arith::ArithDialect, tensor::TensorDialect,
- bufferization::BufferizationDialect,
- linalg::LinalgDialect, memref::MemRefDialect>();
+ target.addLegalDialect<
+ BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
+ tensor::TensorDialect, bufferization::BufferizationDialect,
+ linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
// Make sure the function signature, calls etc. are legal
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return typeConverter.isSignatureLegal(op.getFunctionType());
@@ -832,9 +924,10 @@ struct ConvertMeshToMPIPass
patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
ConvertProcessMultiIndexOp, ConvertGetShardingOp,
- ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
- // ConvertProcessLinearIndexOp accepts an optional worldRank
- patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);
+ ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
+ ConvertProcessLinearIndexOp>(typeConverter, ctxt);
+ SymbolTableCollection symbolTableCollection;
+ mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
index 56d8edfbcc025..6d445ca0e4099 100644
--- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
+++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
@@ -41,6 +42,38 @@ struct FoldCast final : public mlir::OpRewritePattern<OpT> {
return mlir::success();
}
};
+
+struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
+ using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
+ mlir::PatternRewriter &b) const override {
+ auto comm = op.getComm();
+ if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>()) {
+ return mlir::failure();
+ }
+
+ // Try to get DLTI attribute for MPI:comm_world_rank
+ // If found, fold the rank to a constant.
+ {
+ auto dltiAttr = dlti::query(op, {"MPI:comm_world_rank"}, false);
+ if (failed(dltiAttr))
+ return mlir::failure();
+ if (!isa<IntegerAttr>(dltiAttr.value())) {
+ return op->emitError()
+ << "Expected an integer attribute for MPI:comm_world_rank";
+ }
+ Value res = b.create<arith::ConstantIntOp>(
+ op.getLoc(), cast<IntegerAttr>(dltiAttr.value()).getInt(), 32);
+ if (Value retVal = op.getRetval())
+ b.replaceOp(op, {retVal, res});
+ else
+ b.replaceOp(op, res);
+ return mlir::success();
+ }
+ }
+};
+
} // namespace
void mlir::mpi::SendOp::getCanonicalizationPatterns(
@@ -63,6 +96,11 @@ void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
}
+void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
+ mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
+ results.add<FoldRank>(context);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
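For context, this canonicalization only fires when an enclosing op carries the queried DLTI entry; a module-level sketch might look like this (the exact attribute spelling here is an assumption, not quoted from the PR's tests):

// Hypothetical module attribute providing the rank queried above.
module attributes {dlti.map = #dlti.map<"MPI:comm_world_rank" = 7 : i32>} {
  ...
}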
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 304cb55a35086..b84de2b716b32 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -75,6 +75,31 @@ static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) {
return lhs.value() * rhs.value();
}
+/// Converts a mixed list of static sizes and dynamic Values into a vector
+/// of Values of the provided type.
+SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b,
+ const Location &loc,
+ llvm::ArrayRef<int64_t> statics,
+ ValueRange dynamics,
+ Type ...
[truncated]
FYI @BenBrock
Force-pushed from b07686c to 1f65423.
Dinistro left a comment:
Dropped a bunch of comments. The main concern is mixing conversion and non-conversion patterns, which is broken in the general case.
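The registration in question is the one quoted in ConvertMeshToMPIPass::runOnOperation above, where plain rewrite patterns land in the same set as the conversion patterns:

// Conversion patterns, driven by the dialect conversion framework:
patterns.add<..., ConvertAllReduceOp, ConvertProcessLinearIndexOp>(typeConverter, ctxt);
// Ordinary (non-conversion) rewrite patterns added to the same set:
mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);

Roughly, ordinary patterns are unaware of the conversion driver's type materialization and rollback machinery, which is presumably what "broken in the general case" refers to.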
Nit: Drop the default case. Clang and GCC warn about uncovered enum cases at compile time, which they cannot do when a default case is present.
I'll keep it. Compiler checks can be disabled.
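For reference, the style the reviewer suggests would look roughly like this (a sketch, not the PR's code):

switch (kind.getValue()) {
case ReductionKind::Sum:
  return mpi::MPI_OpClassEnumAttr::get(ctx, mpi::MPI_OpClassEnum::MPI_SUM);
// ... one case per ReductionKind, no default ...
}
llvm_unreachable("Unknown/unsupported reduction kind");

With every enumerator covered and no default, -Wswitch can diagnose a newly added ReductionKind at compile time, and llvm_unreachable after the switch silences the missing-return warning.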
Nit: Consider factoring the enum conversion into a separate function to avoid duplicating the attribute construction this often.
I'll keep it like this. Introducing an indirection just to shorten the lines by a few characters doesn't look right to me.
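For comparison, the factoring the reviewer had in mind could be as small as switching over the enum value and constructing the attribute once (a hypothetical shape, not the PR's code):

static mpi::MPI_OpClassEnumAttr getMPIReduction(ReductionKindAttr kind) {
  mpi::MPI_OpClassEnum op;
  switch (kind.getValue()) {
  case ReductionKind::Sum: op = mpi::MPI_OpClassEnum::MPI_SUM; break;
  case ReductionKind::Product: op = mpi::MPI_OpClassEnum::MPI_PROD; break;
  // ... remaining kinds analogous ...
  default: assert(false && "Unknown/unsupported reduction kind");
  }
  return mpi::MPI_OpClassEnumAttr::get(kind.getContext(), op);
}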
Nit: Use a more descriptive name.
like what?
getMPIReductionAttr? getMPIReductionOpAttr?
Nit: Is this unused?
It is used in MeshToMPI.cpp. Another PR will add a use elsewhere.
I wonder if we can cut some of the check lines a bit. They contain tons of uninteresting type information that is a pain to maintain but gives almost no benefit.
Hm, maybe here and there, but we'd need to keep some anyway, to check that collapsing works the right way etc.
I'll leave it as-is for now; if it becomes an issue, we can revisit.
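To make the trade-off concrete (illustrative FileCheck lines, not quoted from the PR's tests): a strict check such as

// CHECK: mpi.allreduce({{.*}}) : memref<6xi32>, memref<6xi32>

pins the collapsed 1-D buffer type, whereas a looser

// CHECK: mpi.allreduce

is easier to maintain but no longer verifies that the collapse produced the expected shape.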
Using a normal builder is not legal; you need to use the provided rewriter for all IR manipulations.
I "copied" this type of use from ComplexToStandard. Otherwise, how would I be able to use these helper functions like indexResultTypes?
Force-pushed from 1d4cc8f to db7fa00.
tkarna left a comment:
Looks good to me!
Force-pushed from 1723c59 to 45b837f.
You can test this locally with the following command:

git-clang-format --diff HEAD~1 HEAD --extensions h,cpp -- mlir/include/mlir/Dialect/MPI/IR/MPI.h mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp mlir/lib/Dialect/MPI/IR/MPIOps.cpp mlir/lib/Dialect/Mesh/IR/MeshOps.cpp mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp

View the diff from clang-format here:
diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
index aaf1d39d4..bbae8f32b 100644
--- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
+++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
@@ -523,19 +523,26 @@ static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
auto ctx = kind.getContext();
switch (kind.getValue()) {
case ReductionKind::Sum:
- return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_SUM);
+ return mpi::MPI_ReductionOpEnumAttr::get(ctx,
+ mpi::MPI_ReductionOpEnum::MPI_SUM);
case ReductionKind::Product:
- return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_PROD);
+ return mpi::MPI_ReductionOpEnumAttr::get(
+ ctx, mpi::MPI_ReductionOpEnum::MPI_PROD);
case ReductionKind::Min:
- return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MIN);
+ return mpi::MPI_ReductionOpEnumAttr::get(ctx,
+ mpi::MPI_ReductionOpEnum::MPI_MIN);
case ReductionKind::Max:
- return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_MAX);
+ return mpi::MPI_ReductionOpEnumAttr::get(ctx,
+ mpi::MPI_ReductionOpEnum::MPI_MAX);
case ReductionKind::BitwiseAnd:
- return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BAND);
+ return mpi::MPI_ReductionOpEnumAttr::get(
+ ctx, mpi::MPI_ReductionOpEnum::MPI_BAND);
case ReductionKind::BitwiseOr:
- return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BOR);
+ return mpi::MPI_ReductionOpEnumAttr::get(ctx,
+ mpi::MPI_ReductionOpEnum::MPI_BOR);
case ReductionKind::BitwiseXor:
- return mpi::MPI_ReductionOpEnumAttr::get(ctx, mpi::MPI_ReductionOpEnum::MPI_BXOR);
+ return mpi::MPI_ReductionOpEnumAttr::get(
+ ctx, mpi::MPI_ReductionOpEnum::MPI_BXOR);
default:
assert(false && "Unknown/unsupported reduction kind");
}
Adding lowering of mesh.allreduce to mpi.allreduce.
Minor restructuring to increase code reuse.
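A minimal way to exercise the new lowering, assuming input.mlir contains a mesh.all_reduce like the sketch at the top:

mlir-opt --convert-mesh-to-mpi input.mlir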